{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "73c45d9e",
   "metadata": {},
   "source": [
    "# 4-1,张量的结构操作\n",
    "\n",
    "本篇我们介绍张量的结构操作。\n",
    "\n",
    "张量结构操作主要包括：张量创建，索引切片，维度变换，合并分割。\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2d5140da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.__version__=2.0.1\n"
     ]
    }
   ],
   "source": [
    "import torch \n",
    "print(\"torch.__version__=\"+torch.__version__) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03129563",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ce5e93cd",
   "metadata": {},
   "source": [
    "### 一，创建张量"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19afd9f2",
   "metadata": {},
   "source": [
    "张量创建的许多方法和numpy中创建array的方法很像。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "67963200",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "38784279",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1., 2., 3.])\n"
     ]
    }
   ],
   "source": [
    "a = torch.tensor([1,2,3],dtype = torch.float)\n",
    "print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "089f2763",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1, 3, 5, 7, 9])\n"
     ]
    }
   ],
   "source": [
    "b = torch.arange(1,10,step = 2)\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24eadf97",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8a56defe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0000, 0.6978, 1.3956, 2.0933, 2.7911, 3.4889, 4.1867, 4.8844, 5.5822,\n",
      "        6.2800])\n"
     ]
    }
   ],
   "source": [
    "c = torch.linspace(0.0,2*3.14,10)\n",
    "print(c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21f59296",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "79ef680d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.]])\n"
     ]
    }
   ],
   "source": [
    "d = torch.zeros((3,3))\n",
    "print(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29402d88",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c6e294e2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1, 1, 1],\n",
      "        [1, 1, 1],\n",
      "        [1, 1, 1]], dtype=torch.int32)\n",
      "tensor([[0., 0., 0.],\n",
      "        [0., 0., 0.],\n",
      "        [0., 0., 0.]])\n"
     ]
    }
   ],
   "source": [
    "a = torch.ones((3,3),dtype = torch.int)\n",
    "b = torch.zeros_like(a,dtype = torch.float)\n",
    "print(a)\n",
    "print(b)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd256014",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "827a9fb4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[5., 5., 5.],\n",
      "        [5., 5., 5.],\n",
      "        [5., 5., 5.]])\n"
     ]
    }
   ],
   "source": [
    "torch.fill_(b,5)\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "515cf001",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a2b191f2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([4.9626, 7.6822, 0.8848, 1.3203, 3.0742])\n"
     ]
    }
   ],
   "source": [
    "#均匀随机分布\n",
    "torch.manual_seed(0)\n",
    "minval,maxval = 0,10\n",
    "a = minval + (maxval-minval)*torch.rand([5])\n",
    "print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "490a81bc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3a0e6228",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.5507,  0.2704,  0.6472],\n",
      "        [ 0.2490, -0.3354,  0.4564],\n",
      "        [-0.6255,  0.4539, -1.3740]])\n"
     ]
    }
   ],
   "source": [
    "#正态分布随机\n",
    "b = torch.normal(mean = torch.zeros(3,3), std = torch.ones(3,3))\n",
    "print(b)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5526dea",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "897c8f4d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[16.2371, -1.6612,  3.9163],\n",
      "        [ 7.4999,  1.5616,  4.0768],\n",
      "        [ 5.2128, -8.9407,  6.4601]])\n"
     ]
    }
   ],
   "source": [
    "#正态分布随机\n",
    "mean,std = 2,5\n",
    "c = std*torch.randn((3,3))+mean\n",
    "print(c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bce7c13a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "fceecfdd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 3, 17,  9, 19,  1, 18,  4, 13, 15, 12,  0, 16,  7, 11,  2,  5,  8, 10,\n",
      "         6, 14])\n"
     ]
    }
   ],
   "source": [
    "#整数随机排列\n",
    "d = torch.randperm(20)\n",
    "print(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1505be5f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d231f907",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 0., 0.],\n",
      "        [0., 1., 0.],\n",
      "        [0., 0., 1.]])\n",
      "tensor([[1, 0, 0],\n",
      "        [0, 2, 0],\n",
      "        [0, 0, 3]])\n"
     ]
    }
   ],
   "source": [
    "#特殊矩阵\n",
    "I = torch.eye(3,3) #单位矩阵\n",
    "print(I)\n",
    "t = torch.diag(torch.tensor([1,2,3])) #对角矩阵\n",
    "print(t)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8f6d4df",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35a38fa0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "f461074e",
   "metadata": {},
   "source": [
    "### 二 ，索引切片"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f460ba5c",
   "metadata": {},
   "source": [
    "张量的索引切片方式和numpy几乎是一样的。切片时支持缺省参数和省略号。\n",
    "\n",
    "可以通过索引和切片对部分元素进行修改。\n",
    "\n",
    "此外，对于不规则的切片提取,可以使用torch.index_select, torch.masked_select, torch.take\n",
    "\n",
    "如果要通过修改张量的某些元素得到新的张量，可以使用torch.where,torch.masked_fill,torch.index_fill"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2d562ecc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[4, 7, 0, 1, 3],\n",
      "        [6, 4, 8, 4, 6],\n",
      "        [3, 4, 0, 1, 2],\n",
      "        [5, 6, 8, 1, 2],\n",
      "        [6, 9, 3, 8, 4]], dtype=torch.int32)\n"
     ]
    }
   ],
   "source": [
    "#均匀随机分布\n",
    "torch.manual_seed(0)\n",
    "minval,maxval = 0,10\n",
    "t = torch.floor(minval + (maxval-minval)*torch.rand([5,5])).int()\n",
    "print(t)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "426d8958",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "cd695542",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([4, 7, 0, 1, 3], dtype=torch.int32)\n"
     ]
    }
   ],
   "source": [
    "#第0行\n",
    "print(t[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2723331",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7e2d299a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([6, 9, 3, 8, 4], dtype=torch.int32)\n"
     ]
    }
   ],
   "source": [
    "#倒数第一行\n",
    "print(t[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87148f17",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "83e9c09d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(4, dtype=torch.int32)\n",
      "tensor(4, dtype=torch.int32)\n"
     ]
    }
   ],
   "source": [
    "#第1行第3列\n",
    "print(t[1,3])\n",
    "print(t[1][3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d479068a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "c3bec0e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[6, 4, 8, 4, 6],\n",
      "        [3, 4, 0, 1, 2],\n",
      "        [5, 6, 8, 1, 2]], dtype=torch.int32)\n"
     ]
    }
   ],
   "source": [
    "#第1行至第3行\n",
    "print(t[1:4,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b056c6aa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[6, 8],\n",
      "        [3, 0],\n",
      "        [5, 8]], dtype=torch.int32)\n"
     ]
    }
   ],
   "source": [
    "#第1行至最后一行，第0列到最后一列每隔两列取一列\n",
    "print(t[1:4,:4:2])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eeda6705-5298-43d0-a42e-13962516200f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "f8989377",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 2.],\n",
       "        [0., 0.]])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#可以使用索引和切片修改部分元素\n",
    "x = torch.Tensor([[1,2],[3,4]])\n",
    "x.data[1,:] = torch.tensor([0.0,0.0])\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37cd24a5-1285-4710-ab1e-2ec3a54609c0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "33abd721",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[ 0,  1,  2],\n",
      "         [ 3,  4,  5],\n",
      "         [ 6,  7,  8]],\n",
      "\n",
      "        [[ 9, 10, 11],\n",
      "         [12, 13, 14],\n",
      "         [15, 16, 17]],\n",
      "\n",
      "        [[18, 19, 20],\n",
      "         [21, 22, 23],\n",
      "         [24, 25, 26]]])\n"
     ]
    }
   ],
   "source": [
    "a = torch.arange(27).view(3,3,3)\n",
    "print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "7ddeb6d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1,  4,  7],\n",
      "        [10, 13, 16],\n",
      "        [19, 22, 25]])\n"
     ]
    }
   ],
   "source": [
    "#省略号可以表示多个冒号\n",
    "print(a[...,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba75917f-f20c-4956-991c-d7d5e25abf80",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "047c830a",
   "metadata": {},
   "source": [
    "以上切片方式相对规则，对于不规则的切片提取,可以使用torch.index_select, torch.take, torch.gather, torch.masked_select.\n",
    "\n",
    "考虑班级成绩册的例子，有4个班级，每个班级5个学生，每个学生7门科目成绩。可以用一个4×5×7的张量来表示。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "1c8edbc7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[55, 95,  3, 18, 37, 30, 93],\n",
      "         [17, 26, 15,  3, 20, 92, 72],\n",
      "         [74, 52, 24, 58,  3, 13, 24],\n",
      "         [81, 79, 27, 48, 81, 99, 69],\n",
      "         [56, 83, 20, 59, 11, 15, 24]],\n",
      "\n",
      "        [[72, 70, 20, 65, 77, 43, 51],\n",
      "         [61, 81, 98, 11, 31, 69, 91],\n",
      "         [93, 94, 59,  6, 54, 18,  3],\n",
      "         [94, 88,  0, 59, 41, 41, 27],\n",
      "         [69, 20, 68, 75, 85, 68,  0]],\n",
      "\n",
      "        [[17, 74, 60, 10, 21, 97, 83],\n",
      "         [28, 37,  2, 49, 12, 11, 47],\n",
      "         [57, 29, 79, 19, 95, 84,  7],\n",
      "         [37, 52, 57, 61, 69, 52, 25],\n",
      "         [73,  2, 20, 37, 25, 32,  9]],\n",
      "\n",
      "        [[39, 60, 17, 47, 85, 44, 51],\n",
      "         [45, 60, 81, 97, 81, 97, 46],\n",
      "         [ 5, 26, 84, 49, 25, 11,  3],\n",
      "         [ 7, 39, 77, 77,  1, 81, 10],\n",
      "         [39, 29, 40, 40,  5,  6, 42]]], dtype=torch.int32)\n"
     ]
    }
   ],
   "source": [
    "minval=0\n",
    "maxval=100\n",
    "scores = torch.floor(minval + (maxval-minval)*torch.rand([4,5,7])).int()\n",
    "print(scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "3cc30352",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[55, 95,  3, 18, 37, 30, 93],\n",
       "         [74, 52, 24, 58,  3, 13, 24],\n",
       "         [56, 83, 20, 59, 11, 15, 24]],\n",
       "\n",
       "        [[72, 70, 20, 65, 77, 43, 51],\n",
       "         [93, 94, 59,  6, 54, 18,  3],\n",
       "         [69, 20, 68, 75, 85, 68,  0]],\n",
       "\n",
       "        [[17, 74, 60, 10, 21, 97, 83],\n",
       "         [57, 29, 79, 19, 95, 84,  7],\n",
       "         [73,  2, 20, 37, 25, 32,  9]],\n",
       "\n",
       "        [[39, 60, 17, 47, 85, 44, 51],\n",
       "         [ 5, 26, 84, 49, 25, 11,  3],\n",
       "         [39, 29, 40, 40,  5,  6, 42]]], dtype=torch.int32)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#抽取每个班级第0个学生，第2个学生，第4个学生的全部成绩\n",
    "torch.index_select(scores,dim = 1,index = torch.tensor([0,2,4]))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "e82bc37a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[95, 18, 93],\n",
      "         [52, 58, 24],\n",
      "         [83, 59, 24]],\n",
      "\n",
      "        [[70, 65, 51],\n",
      "         [94,  6,  3],\n",
      "         [20, 75,  0]],\n",
      "\n",
      "        [[74, 10, 83],\n",
      "         [29, 19,  7],\n",
      "         [ 2, 37,  9]],\n",
      "\n",
      "        [[60, 47, 51],\n",
      "         [26, 49,  3],\n",
      "         [29, 40, 42]]], dtype=torch.int32)\n"
     ]
    }
   ],
   "source": [
    "#抽取每个班级第0个学生，第2个学生，第4个学生的第1门课程，第3门课程，第6门课程成绩\n",
    "q = torch.index_select(torch.index_select(scores,dim = 1,index = torch.tensor([0,2,4]))\n",
    "                   ,dim=2,index = torch.tensor([1,3,6]))\n",
    "print(q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "19230d02",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([55, 52, 42], dtype=torch.int32)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#抽取第0个班级第0个学生的第0门课程，第2个班级的第3个学生的第1门课程，第3个班级的第4个学生第6门课程成绩\n",
    "#take将输入看成一维数组，输出和index同形状\n",
    "s = torch.take(scores,torch.tensor([0*5*7+0,2*5*7+3*7+1,3*5*7+4*7+6]))\n",
    "s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "d753e409",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([95, 93, 92, 81, 81, 99, 83, 81, 98, 91, 93, 94, 94, 88, 85, 97, 83, 95,\n",
      "        84, 85, 81, 97, 81, 97, 84, 81], dtype=torch.int32)\n"
     ]
    }
   ],
   "source": [
    "#抽取分数大于等于80分的分数（布尔索引）\n",
    "#结果是1维张量\n",
    "g = torch.masked_select(scores,scores>=80)\n",
    "print(g)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3266889",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "fd888de5",
   "metadata": {},
   "source": [
    "以上这些方法仅能提取张量的部分元素值，但不能更改张量的部分元素值得到新的张量。\n",
    "\n",
    "如果要通过修改张量的部分元素值得到新的张量，可以使用torch.where,torch.index_fill 和 torch.masked_fill\n",
    "\n",
    "torch.where可以理解为if的张量版本。\n",
    "\n",
    "torch.index_fill的选取元素逻辑和torch.index_select相同。\n",
    "\n",
    "torch.masked_fill的选取元素逻辑和torch.masked_select相同。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "59ebf9da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[0, 1, 0, 0, 0, 0, 1],\n",
      "         [0, 0, 0, 0, 0, 1, 1],\n",
      "         [1, 0, 0, 0, 0, 0, 0],\n",
      "         [1, 1, 0, 0, 1, 1, 1],\n",
      "         [0, 1, 0, 0, 0, 0, 0]],\n",
      "\n",
      "        [[1, 1, 0, 1, 1, 0, 0],\n",
      "         [1, 1, 1, 0, 0, 1, 1],\n",
      "         [1, 1, 0, 0, 0, 0, 0],\n",
      "         [1, 1, 0, 0, 0, 0, 0],\n",
      "         [1, 0, 1, 1, 1, 1, 0]],\n",
      "\n",
      "        [[0, 1, 0, 0, 0, 1, 1],\n",
      "         [0, 0, 0, 0, 0, 0, 0],\n",
      "         [0, 0, 1, 0, 1, 1, 0],\n",
      "         [0, 0, 0, 1, 1, 0, 0],\n",
      "         [1, 0, 0, 0, 0, 0, 0]],\n",
      "\n",
      "        [[0, 0, 0, 0, 1, 0, 0],\n",
      "         [0, 0, 1, 1, 1, 1, 0],\n",
      "         [0, 0, 1, 0, 0, 0, 0],\n",
      "         [0, 0, 1, 1, 0, 1, 0],\n",
      "         [0, 0, 0, 0, 0, 0, 0]]])\n"
     ]
    }
   ],
   "source": [
    "#如果分数大于60分，赋值成1，否则赋值成0\n",
    "ifpass = torch.where(scores>60,torch.tensor(1),torch.tensor(0))\n",
    "print(ifpass)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92b72965",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "8daaca71",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[100, 100, 100, 100, 100, 100, 100],\n",
       "         [ 17,  26,  15,   3,  20,  92,  72],\n",
       "         [100, 100, 100, 100, 100, 100, 100],\n",
       "         [ 81,  79,  27,  48,  81,  99,  69],\n",
       "         [100, 100, 100, 100, 100, 100, 100]],\n",
       "\n",
       "        [[100, 100, 100, 100, 100, 100, 100],\n",
       "         [ 61,  81,  98,  11,  31,  69,  91],\n",
       "         [100, 100, 100, 100, 100, 100, 100],\n",
       "         [ 94,  88,   0,  59,  41,  41,  27],\n",
       "         [100, 100, 100, 100, 100, 100, 100]],\n",
       "\n",
       "        [[100, 100, 100, 100, 100, 100, 100],\n",
       "         [ 28,  37,   2,  49,  12,  11,  47],\n",
       "         [100, 100, 100, 100, 100, 100, 100],\n",
       "         [ 37,  52,  57,  61,  69,  52,  25],\n",
       "         [100, 100, 100, 100, 100, 100, 100]],\n",
       "\n",
       "        [[100, 100, 100, 100, 100, 100, 100],\n",
       "         [ 45,  60,  81,  97,  81,  97,  46],\n",
       "         [100, 100, 100, 100, 100, 100, 100],\n",
       "         [  7,  39,  77,  77,   1,  81,  10],\n",
       "         [100, 100, 100, 100, 100, 100, 100]]], dtype=torch.int32)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#将每个班级第0个学生，第2个学生，第4个学生的全部成绩赋值成满分\n",
    "torch.index_fill(scores,dim = 1,index = torch.tensor([0,2,4]),value = 100)\n",
    "#等价于 scores.index_fill(dim = 1,index = torch.tensor([0,2,4]),value = 100)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "1cdbf318",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[60, 95, 60, 60, 60, 60, 93],\n",
       "         [60, 60, 60, 60, 60, 92, 72],\n",
       "         [74, 60, 60, 60, 60, 60, 60],\n",
       "         [81, 79, 60, 60, 81, 99, 69],\n",
       "         [60, 83, 60, 60, 60, 60, 60]],\n",
       "\n",
       "        [[72, 70, 60, 65, 77, 60, 60],\n",
       "         [61, 81, 98, 60, 60, 69, 91],\n",
       "         [93, 94, 60, 60, 60, 60, 60],\n",
       "         [94, 88, 60, 60, 60, 60, 60],\n",
       "         [69, 60, 68, 75, 85, 68, 60]],\n",
       "\n",
       "        [[60, 74, 60, 60, 60, 97, 83],\n",
       "         [60, 60, 60, 60, 60, 60, 60],\n",
       "         [60, 60, 79, 60, 95, 84, 60],\n",
       "         [60, 60, 60, 61, 69, 60, 60],\n",
       "         [73, 60, 60, 60, 60, 60, 60]],\n",
       "\n",
       "        [[60, 60, 60, 60, 85, 60, 60],\n",
       "         [60, 60, 81, 97, 81, 97, 60],\n",
       "         [60, 60, 84, 60, 60, 60, 60],\n",
       "         [60, 60, 77, 77, 60, 81, 60],\n",
       "         [60, 60, 60, 60, 60, 60, 60]]], dtype=torch.int32)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#将分数小于60分的分数赋值成60分\n",
    "b = torch.masked_fill(scores,scores<60,60)\n",
    "#等价于b = scores.masked_fill(scores<60,60)\n",
    "b\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7c5b89d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "f840b1ef",
   "metadata": {},
   "source": [
    "### 三，维度变换"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50c75e1d",
   "metadata": {},
   "source": [
    "维度变换相关函数主要有 torch.reshape(或者调用张量的view方法), torch.squeeze, torch.unsqueeze, torch.transpose\n",
    "\n",
    "torch.reshape 可以改变张量的形状。\n",
    "\n",
    "torch.squeeze 可以减少维度。\n",
    "\n",
    "torch.unsqueeze 可以增加维度。\n",
    "\n",
    "torch.transpose/torch.permute 可以交换维度。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "6f2eb3fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 3, 3, 2])\n",
      "tensor([[[[126, 195],\n",
      "          [ 22,  33],\n",
      "          [ 78, 161]],\n",
      "\n",
      "         [[124, 228],\n",
      "          [116, 161],\n",
      "          [ 88, 102]],\n",
      "\n",
      "         [[  5,  43],\n",
      "          [ 74, 132],\n",
      "          [177, 204]]]], dtype=torch.int32)\n"
     ]
    }
   ],
   "source": [
    "# 张量的view方法有时候会调用失败，可以使用reshape方法。\n",
    "\n",
    "torch.manual_seed(0)\n",
    "minval,maxval = 0,255\n",
    "a = (minval + (maxval-minval)*torch.rand([1,3,3,2])).int()\n",
    "print(a.shape)\n",
    "print(a)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "99a96211",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 6])\n",
      "tensor([[126, 195,  22,  33,  78, 161],\n",
      "        [124, 228, 116, 161,  88, 102],\n",
      "        [  5,  43,  74, 132, 177, 204]], dtype=torch.int32)\n"
     ]
    }
   ],
   "source": [
    "# 改成 （3,6）形状的张量\n",
    "b = a.view([3,6]) #torch.reshape(a,[3,6])\n",
    "print(b.shape)\n",
    "print(b)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "af828000",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[126, 195],\n",
      "          [ 22,  33],\n",
      "          [ 78, 161]],\n",
      "\n",
      "         [[124, 228],\n",
      "          [116, 161],\n",
      "          [ 88, 102]],\n",
      "\n",
      "         [[  5,  43],\n",
      "          [ 74, 132],\n",
      "          [177, 204]]]], dtype=torch.int32)\n"
     ]
    }
   ],
   "source": [
    "# 改回成 [1,3,3,2] 形状的张量\n",
    "c = torch.reshape(b,[1,3,3,2]) # b.view([1,3,3,2]) \n",
    "print(c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7d86688",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "619a6692",
   "metadata": {},
   "source": [
    "如果张量在某个维度上只有一个元素，利用torch.squeeze可以消除这个维度。\n",
    "\n",
    "torch.unsqueeze的作用和torch.squeeze的作用相反。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "c5fcad41",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1., 2.]])\n",
      "tensor([1., 2.])\n",
      "torch.Size([1, 2])\n",
      "torch.Size([2])\n"
     ]
    }
   ],
   "source": [
    "a = torch.tensor([[1.0,2.0]])\n",
    "s = torch.squeeze(a)\n",
    "print(a)\n",
    "print(s)\n",
    "print(a.shape)\n",
    "print(s.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "250cab70",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1., 2.])\n",
      "tensor([[1., 2.]])\n",
      "torch.Size([2])\n",
      "torch.Size([1, 2])\n"
     ]
    }
   ],
   "source": [
    "#在第0维插入长度为1的一个维度\n",
    "\n",
    "d = torch.unsqueeze(s,axis=0)  \n",
    "print(s)\n",
    "print(d)\n",
    "\n",
    "print(s.shape)\n",
    "print(d.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34d168d1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "1b51c3b8",
   "metadata": {},
   "source": [
    "torch.transpose可以交换张量的维度，torch.transpose常用于图片存储格式的变换上。\n",
    "\n",
    "如果是二维的矩阵，通常会调用矩阵的转置方法 matrix.t()，等价于 torch.transpose(matrix,0,1)。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "f0214621",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([100, 256, 256, 4])\n",
      "torch.Size([100, 4, 256, 256])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([100, 4, 256, 256])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "minval=0\n",
    "maxval=255\n",
    "# Batch,Height,Width,Channel\n",
    "data = torch.floor(minval + (maxval-minval)*torch.rand([100,256,256,4])).int()\n",
    "print(data.shape)\n",
    "\n",
    "# 转换成 Pytorch默认的图片格式 Batch,Channel,Height,Width \n",
    "# 需要交换两次\n",
    "data_t = torch.transpose(torch.transpose(data,1,2),1,3)\n",
    "print(data_t.shape)\n",
    "\n",
    "\n",
    "data_p = torch.permute(data,[0,3,1,2]) #对维度的顺序做重新编排\n",
    "data_p.shape \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "eb195182",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "217205e5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1, 2, 3],\n",
      "        [4, 5, 6]])\n",
      "tensor([[1, 4],\n",
      "        [2, 5],\n",
      "        [3, 6]])\n"
     ]
    }
   ],
   "source": [
    "matrix = torch.tensor([[1,2,3],[4,5,6]])\n",
    "print(matrix)\n",
    "print(matrix.t()) #等价于torch.transpose(matrix,0,1)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ea95438",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "1af8cd68",
   "metadata": {},
   "source": [
    "### 四，合并分割"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d0768d9",
   "metadata": {},
   "source": [
    "可以用torch.cat方法和torch.stack方法将多个张量合并，可以用torch.split方法把一个张量分割成多个张量。\n",
    "\n",
    "torch.cat和torch.stack有略微的区别，torch.cat是连接，不会增加维度，而torch.stack是堆叠，会增加维度。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "a4d8c45d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([6, 2])\n",
      "tensor([[ 1.,  2.],\n",
      "        [ 3.,  4.],\n",
      "        [ 5.,  6.],\n",
      "        [ 7.,  8.],\n",
      "        [ 9., 10.],\n",
      "        [11., 12.]])\n"
     ]
    }
   ],
   "source": [
    "a = torch.tensor([[1.0,2.0],[3.0,4.0]])\n",
    "b = torch.tensor([[5.0,6.0],[7.0,8.0]])\n",
    "c = torch.tensor([[9.0,10.0],[11.0,12.0]])\n",
    "\n",
    "abc_cat = torch.cat([a,b,c],dim = 0)\n",
    "print(abc_cat.shape)\n",
    "print(abc_cat)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "c5f1b5ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 2, 2])\n",
      "tensor([[[ 1.,  2.],\n",
      "         [ 3.,  4.]],\n",
      "\n",
      "        [[ 5.,  6.],\n",
      "         [ 7.,  8.]],\n",
      "\n",
      "        [[ 9., 10.],\n",
      "         [11., 12.]]])\n"
     ]
    }
   ],
   "source": [
    "abc_stack = torch.stack([a,b,c],axis = 0) #torch中dim和axis参数名可以混用\n",
    "print(abc_stack.shape)\n",
    "print(abc_stack)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "ae893622",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.,  2.,  5.,  6.,  9., 10.],\n",
       "        [ 3.,  4.,  7.,  8., 11., 12.]])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.cat([a,b,c],axis = 1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "550ec0d4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 1.,  2.],\n",
       "         [ 5.,  6.],\n",
       "         [ 9., 10.]],\n",
       "\n",
       "        [[ 3.,  4.],\n",
       "         [ 7.,  8.],\n",
       "         [11., 12.]]])"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.stack([a,b,c],axis = 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dadfc66",
   "metadata": {},
   "source": [
    "torch.split是torch.cat的逆运算，可以指定分割份数平均分割，也可以通过指定每份的记录数量进行分割。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "50e02ce3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.,  2.],\n",
      "        [ 3.,  4.],\n",
      "        [ 5.,  6.],\n",
      "        [ 7.,  8.],\n",
      "        [ 9., 10.],\n",
      "        [11., 12.]])\n",
      "tensor([[1., 2.],\n",
      "        [3., 4.]])\n",
      "tensor([[5., 6.],\n",
      "        [7., 8.]])\n",
      "tensor([[ 9., 10.],\n",
      "        [11., 12.]])\n"
     ]
    }
   ],
   "source": [
    "print(abc_cat)\n",
    "a,b,c = torch.split(abc_cat,split_size_or_sections = 2,dim = 0) #每份2个进行分割\n",
    "print(a)\n",
    "print(b)\n",
    "print(c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "fae5c5f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.,  2.],\n",
      "        [ 3.,  4.],\n",
      "        [ 5.,  6.],\n",
      "        [ 7.,  8.],\n",
      "        [ 9., 10.],\n",
      "        [11., 12.]])\n",
      "tensor([[1., 2.],\n",
      "        [3., 4.],\n",
      "        [5., 6.],\n",
      "        [7., 8.]])\n",
      "tensor([[ 9., 10.]])\n",
      "tensor([[11., 12.]])\n"
     ]
    }
   ],
   "source": [
    "print(abc_cat)\n",
    "p,q,r = torch.split(abc_cat,split_size_or_sections =[4,1,1],dim = 0) #每份分别为[4,1,1]\n",
    "print(p)\n",
    "print(q)\n",
    "print(r)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e57df33b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "10750702",
   "metadata": {},
   "source": [
    "**如果本书对你有所帮助，想鼓励一下作者，记得给本项目加一颗星星star⭐️，并分享给你的朋友们喔😊!** \n",
    "\n",
    "如果对本书内容理解上有需要进一步和作者交流的地方，欢迎在公众号\"算法美食屋\"下留言。作者时间和精力有限，会酌情予以回复。\n",
    "\n",
    "也可以在公众号后台回复关键字：**加群**，加入读者交流群和大家讨论。\n",
    "\n",
    "![算法美食屋logo.png](https://tva1.sinaimg.cn/large/e6c9d24egy1h41m2zugguj20k00b9q46.jpg)"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
